package pers.mihao.ancient_empire.common.util;

import javassist.*;
import javassist.bytecode.AnnotationsAttribute;
import javassist.bytecode.ClassFile;
import javassist.bytecode.ConstPool;
import javassist.bytecode.annotation.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import pers.mihao.ancient_empire.common.vo.BaseException;

import javax.validation.*;
import java.lang.annotation.Annotation;
import java.lang.reflect.Array;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.*;
import java.util.stream.Collectors;

/**
 * Bean Validation helper for method invocations.
 *
 * <p>{@link #validate(Class, Method, Object[])} checks two kinds of constraints:
 * <ul>
 * <li>constraint annotations placed directly on method parameters. A synthetic class is generated with
 * Javassist that has one public field per parameter carrying the same constraint annotations; the
 * arguments are copied into an instance of that class, which is then validated;</li>
 * <li>constraints declared inside the argument objects themselves. Arrays, collections and maps are
 * traversed (map keys and values both) and every non-simple element is validated.</li>
 * </ul>
 * The first violation, ordered by property path, is logged and rethrown as a {@link BaseException}.
 *
 * @author mihao
 * @version 1.0
 * @date 2021\1\17 0017 19:26
 */
public final class ValidateUtil {

	private static final Logger log = LoggerFactory.getLogger(ValidateUtil.class);

	/** Guards generation of parameter bean classes so the same class is never defined twice. */
	private static final Object GENERATION_LOCK = new Object();

	/** Lazily created singleton; volatile so the double-checked locking in getInstance() is safe. */
	private static volatile ValidateUtil instance;

	private final Validator validator;

	private ValidateUtil() {
		validator = Validation.buildDefaultValidatorFactory().getValidator();
	}

	/**
	 * Validates the arguments of a method invocation and throws on the first violation.
	 *
	 * @param clazz  the class the method is invoked on; its class loader is used to look up and define
	 *               the generated parameter bean class
	 * @param method the invoked method whose parameter constraint annotations are checked
	 * @param args   the actual arguments in parameter order; {@code null} is treated as "no arguments"
	 * @throws BaseException carrying the message of the first violation (ordered by property path)
	 */
	public static void validate(Class<?> clazz, Method method, Object[] args) {
		getInstance().doValidate(clazz, method, args);
	}

	/**
	 * Returns the shared instance, creating it on first use. The null check is repeated inside the lock
	 * so concurrent first callers cannot create more than one instance.
	 */
	private static ValidateUtil getInstance() {
		ValidateUtil local = instance;
		if (local == null) {
			synchronized (ValidateUtil.class) {
				local = instance;
				if (local == null) {
					local = new ValidateUtil();
					instance = local;
				}
			}
		}
		return local;
	}

	/** Whether {@code cls} (or, for arrays, its component type) is a simple value type that is not traversed. */
	private static boolean isPrimitives(Class<?> cls) {
		if (cls.isArray()) {
			return isPrimitive(cls.getComponentType());
		}
		return isPrimitive(cls);
	}

	/** Primitives, String, Boolean, Character, Number and Date subtypes are treated as simple values. */
	private static boolean isPrimitive(Class<?> cls) {
		return cls.isPrimitive() || cls == String.class || cls == Boolean.class || cls == Character.class
				|| Number.class.isAssignableFrom(cls) || Date.class.isAssignableFrom(cls);
	}

	/**
	 * Builds an instance of the generated parameter bean for {@code method} holding {@code args}.
	 *
	 * @return the populated bean, or {@code null} when the method has no constrained parameter or the bean
	 *         could not be built (best effort: the failure is logged and only parameter-level constraints
	 *         are skipped; the argument objects are still validated by the caller)
	 */
	private Object getMethodParameterBean(Class<?> clazz, Method method, Object[] args) {
		if (!hasConstraintParameter(method)) {
			return null;
		}
		try {
			Class<?> parameterClass = getOrCreateParameterClass(clazz, method);
			Object parameterBean = parameterClass.getDeclaredConstructor().newInstance();
			for (int i = 0; i < args.length; i++) {
				Field field = parameterClass.getField(argumentFieldName(method, i));
				field.set(parameterBean, args[i]);
			}
			return parameterBean;
		} catch (Exception | LinkageError e) {
			log.warn("Failed to build parameter bean for method {}: {}", method, e.getMessage(), e);
			return null;
		}
	}

	/**
	 * Returns the generated parameter bean class for {@code method}, generating it on first use.
	 * Lookup is retried under a lock so two threads never try to define the same class.
	 */
	private static Class<?> getOrCreateParameterClass(Class<?> clazz, Method method) throws Exception {
		String parameterClassName = generateMethodParameterClassName(clazz, method);
		ClassLoader loader = clazz.getClassLoader();
		try {
			return Class.forName(parameterClassName, true, loader);
		} catch (ClassNotFoundException notYetGenerated) {
			// fall through and generate it under the lock
		}
		synchronized (GENERATION_LOCK) {
			try {
				return Class.forName(parameterClassName, true, loader);
			} catch (ClassNotFoundException notYetGenerated) {
				return createParameterClass(parameterClassName, loader, method);
			}
		}
	}

	/**
	 * Generates and defines, in {@code loader}, a class with a public no-arg constructor and one public field
	 * per parameter of {@code method}. Each field copies the parameter's constraint annotations.
	 */
	private static Class<?> createParameterClass(String parameterClassName, ClassLoader loader, Method method)
			throws Exception {
		// A pool that can see the same classes as the target loader, so annotation and parameter types
		// resolve even when they are not on the system class path.
		ClassPool pool = new ClassPool(true);
		if (loader != null) {
			pool.appendClassPath(new LoaderClassPath(loader));
		}
		CtClass ctClass = pool.makeClass(parameterClassName);
		ClassFile classFile = ctClass.getClassFile();
		classFile.setVersionToJava5();
		ConstPool constPool = classFile.getConstPool();
		ctClass.addConstructor(CtNewConstructor.defaultConstructor(ctClass));

		Class<?>[] parameterTypes = method.getParameterTypes();
		Annotation[][] parameterAnnotations = method.getParameterAnnotations();
		for (int i = 0; i < parameterTypes.length; i++) {
			AnnotationsAttribute attribute = new AnnotationsAttribute(constPool, AnnotationsAttribute.visibleTag);
			for (Annotation annotation : parameterAnnotations[i]) {
				if (isConstraint(annotation)) {
					attribute.addAnnotation(toJavassistAnnotation(pool, constPool, annotation));
				}
			}
			CtField ctField = CtField.make("public " + parameterTypes[i].getCanonicalName() + " "
					+ argumentFieldName(method, i) + ";", ctClass);
			ctField.getFieldInfo().addAttribute(attribute);
			ctClass.addField(ctField);
		}
		return ctClass.toClass(loader, null);
	}

	/** Copies a runtime annotation, with all its non-null member values, into a Javassist annotation. */
	private static javassist.bytecode.annotation.Annotation toJavassistAnnotation(ClassPool pool,
			ConstPool constPool, Annotation annotation) throws Exception {
		Class<? extends Annotation> annotationType = annotation.annotationType();
		javassist.bytecode.annotation.Annotation copy = new javassist.bytecode.annotation.Annotation(
				constPool, pool.get(annotationType.getName()));
		for (Method member : annotationType.getMethods()) {
			if (Modifier.isPublic(member.getModifiers())
					&& member.getParameterTypes().length == 0
					&& member.getDeclaringClass() == annotationType) {
				Object value = member.invoke(annotation);
				if (value != null) {
					MemberValue memberValue = createMemberValue(
							constPool, pool.get(member.getReturnType().getName()), value);
					copy.addMemberValue(member.getName(), memberValue);
				}
			}
		}
		return copy;
	}

	/** Name of the generated field holding the {@code index}-th argument of {@code method}. */
	private static String argumentFieldName(Method method, int index) {
		return method.getName() + "Argument" + index;
	}

	/**
	 * Builds the binary name of the parameter bean class: {@code <clazz>_<Method>Parameter_<type>...}.
	 * Parameter type names are reduced to identifier characters, because names such as
	 * {@code [Ljava.lang.String;} contain characters that are illegal in a class name, and dots would
	 * otherwise move the generated class into a different package.
	 */
	private static String generateMethodParameterClassName(Class<?> clazz, Method method) {
		StringBuilder builder = new StringBuilder().append(clazz.getName())
				.append("_")
				.append(toUpperMethodName(method.getName()))
				.append("Parameter");
		for (Class<?> parameterType : method.getParameterTypes()) {
			builder.append("_").append(toIdentifier(parameterType.getName()));
		}
		return builder.toString();
	}

	/** Replaces every character that may not appear in a Java identifier with {@code '_'}. */
	private static String toIdentifier(String name) {
		StringBuilder builder = new StringBuilder(name.length());
		for (int i = 0; i < name.length(); i++) {
			char c = name.charAt(i);
			builder.append(Character.isJavaIdentifierPart(c) ? c : '_');
		}
		return builder.toString();
	}

	/** Whether any parameter of {@code method} carries a constraint annotation. */
	private static boolean hasConstraintParameter(Method method) {
		for (Annotation[] annotations : method.getParameterAnnotations()) {
			for (Annotation annotation : annotations) {
				if (isConstraint(annotation)) {
					return true;
				}
			}
		}
		return false;
	}

	/** Whether {@code annotation}'s type is meta-annotated with {@link Constraint}. */
	private static boolean isConstraint(Annotation annotation) {
		return annotation.annotationType().isAnnotationPresent(Constraint.class);
	}

	private static String toUpperMethodName(String methodName) {
		return methodName.substring(0, 1).toUpperCase() + methodName.substring(1);
	}

	/**
	 * Converts an annotation member value into the matching Javassist {@link MemberValue}, recursing into
	 * arrays. Nested annotation values are not copied (the member is left as created by
	 * {@code Annotation.createMemberValue}).
	 */
	private static MemberValue createMemberValue(ConstPool cp, CtClass type, Object value) throws NotFoundException {
		MemberValue memberValue = javassist.bytecode.annotation.Annotation.createMemberValue(cp, type);
		if (memberValue instanceof BooleanMemberValue) {
			((BooleanMemberValue) memberValue).setValue((Boolean) value);
		} else if (memberValue instanceof ByteMemberValue) {
			((ByteMemberValue) memberValue).setValue((Byte) value);
		} else if (memberValue instanceof CharMemberValue) {
			((CharMemberValue) memberValue).setValue((Character) value);
		} else if (memberValue instanceof ShortMemberValue) {
			((ShortMemberValue) memberValue).setValue((Short) value);
		} else if (memberValue instanceof IntegerMemberValue) {
			((IntegerMemberValue) memberValue).setValue((Integer) value);
		} else if (memberValue instanceof LongMemberValue) {
			((LongMemberValue) memberValue).setValue((Long) value);
		} else if (memberValue instanceof FloatMemberValue) {
			((FloatMemberValue) memberValue).setValue((Float) value);
		} else if (memberValue instanceof DoubleMemberValue) {
			((DoubleMemberValue) memberValue).setValue((Double) value);
		} else if (memberValue instanceof ClassMemberValue) {
			((ClassMemberValue) memberValue).setValue(((Class<?>) value).getName());
		} else if (memberValue instanceof StringMemberValue) {
			((StringMemberValue) memberValue).setValue((String) value);
		} else if (memberValue instanceof EnumMemberValue) {
			((EnumMemberValue) memberValue).setValue(((Enum<?>) value).name());
		} else if (memberValue instanceof ArrayMemberValue) {
			CtClass componentType = type.getComponentType();
			int len = Array.getLength(value);
			MemberValue[] members = new MemberValue[len];
			for (int i = 0; i < len; i++) {
				members[i] = createMemberValue(cp, componentType, Array.get(value, i));
			}
			((ArrayMemberValue) memberValue).setValue(members);
		}
		return memberValue;
	}

	/**
	 * Collects violations from the parameter bean and from every argument, then throws for the one with
	 * the smallest property path (first one wins on ties).
	 */
	private void doValidate(Class<?> clazz, Method method, Object[] arguments) {
		Object[] args = arguments == null ? new Object[0] : arguments;
		List<ConstraintViolation<?>> violations = new ArrayList<>();
		Object parameterBean = getMethodParameterBean(clazz, method, args);
		if (parameterBean != null) {
			violations.addAll(validator.validate(parameterBean));
		}
		for (Object arg : args) {
			collectViolations(violations, arg);
		}
		if (violations.isEmpty()) {
			return;
		}
		Comparator<ConstraintViolation<?>> byPath = Comparator.comparing(v -> v.getPropertyPath().toString());
		ConstraintViolation<?> c = Collections.min(violations, byPath);
		log.error("校验方法:{}错误, 属性:{}, 值:{}, 异常:{}", method.getName(), c.getPropertyPath(),
				c.getInvalidValue(), c.getMessage());
		throw new BaseException(c.getMessage());
	}

	/**
	 * Validates {@code arg} and adds its violations to {@code violations}. Arrays, collections and maps
	 * (keys and values) are traversed recursively; {@code null} and simple values are skipped.
	 */
	private void collectViolations(List<ConstraintViolation<?>> violations, Object arg) {
		if (arg == null || isPrimitives(arg.getClass())) {
			return;
		}
		if (arg instanceof Object[]) {
			for (Object item : (Object[]) arg) {
				collectViolations(violations, item);
			}
		} else if (arg instanceof Collection) {
			for (Object item : (Collection<?>) arg) {
				collectViolations(violations, item);
			}
		} else if (arg instanceof Map) {
			for (Map.Entry<?, ?> entry : ((Map<?, ?>) arg).entrySet()) {
				collectViolations(violations, entry.getKey());
				collectViolations(violations, entry.getValue());
			}
		} else {
			violations.addAll(validator.validate(arg));
		}
	}
}
